import sys
import pickle
import numpy as np
import statistics
import scipy.stats as stats
from sklearn import metrics
import pandas as pd

# Unpickle the input data.
# sys.argv[1:] are pickle files; the first three are used, in order, as
# onset_dict, time and event_list (any further files are loaded but unused).
# SECURITY: pickle.load can execute arbitrary code embedded in the file, so
# only pass pickles produced by a trusted upstream step.
files = sys.argv[1:]
if len(files) < 3:
	sys.exit(f'usage: {sys.argv[0]} ONSET_DICT_PKL TIME_PKL EVENT_LIST_PKL')

unpickle = []
for path in files:
	# 'with' guarantees the handle is closed even if pickle.load raises.
	with open(path, 'rb') as f:
		unpickle.append(pickle.load(f))

onset_dict = unpickle[0]
time = unpickle[1]
event_list = unpickle[2]

# Collect the onset and peak times of every event that has an onset time.
# A missing onset is marked with the string 'nan' (the same sentinel this script
# writes for undefined statistics). A real float NaN or None is also treated as
# missing so it cannot silently poison the mean/SEM/SD below.
# onset_dict itself is NOT modified; the filtered copies live in
# temp_onset_times / temp_peak_times.
onset_times = onset_dict['Peak Onset Times']
peak_times = onset_dict['Max Peak Times']
temp_onset_times = []
temp_peak_times = []

for event_idx, onset_time in enumerate(onset_times):
	if onset_time == 'nan' or pd.isna(onset_time):
		continue
	temp_onset_times.append(onset_time)
	# Indexed (not zipped) so a shorter peak list still fails loudly.
	temp_peak_times.append(peak_times[event_idx])

event_num_check = len(temp_onset_times)

# Mean, SEM and sample SD need at least two events; otherwise report 'nan'.
if event_num_check > 1:
	avg_onset_time = statistics.fmean(temp_onset_times)
	sem_onset_time = stats.sem(temp_onset_times)
	std_onset_time = statistics.stdev(temp_onset_times)
else:
	avg_onset_time = 'nan'
	sem_onset_time = 'nan'
	std_onset_time = 'nan'

# Time from onset to peak for every event, aligned index-for-index with
# onset_dict['Peak Onset Times']. Events without an onset ('nan' sentinel, or a
# real NaN/None) get the 'nan' sentinel.
# Mean/SEM are computed over the valid events only.
onset_to_peak = []
temp_onset_to_peak = []
max_peak_times = onset_dict['Max Peak Times']

for event_idx, onset_time in enumerate(onset_dict['Peak Onset Times']):
	if onset_time == 'nan' or pd.isna(onset_time):
		onset_to_peak.append('nan')
		continue
	rise_time = max_peak_times[event_idx] - onset_time
	onset_to_peak.append(rise_time)
	temp_onset_to_peak.append(rise_time)

# Guard on the sample actually being averaged: SEM needs at least two values.
if len(temp_onset_to_peak) > 1:
	avg_onset_to_peak = statistics.fmean(temp_onset_to_peak)
	sem_onset_to_peak = stats.sem(temp_onset_to_peak)
else:
	avg_onset_to_peak = 'nan'
	sem_onset_to_peak = 'nan'
	
# Mean and SEM of the per-event extreme amplitudes, taken over every entry of
# the amplitude lists (no minimum-event-count check is applied here).
# NOTE(review): a 'nan' string entry in these lists would make fmean raise;
# confirm the upstream step never writes one.
max_peak_amps = onset_dict['Max Peak Amps']
min_valley_amps = onset_dict['Min Valley Amps']
avg_max_peak_amp, sem_max_peak_amp = statistics.fmean(max_peak_amps), stats.sem(max_peak_amps)
avg_min_valley_amp, sem_min_valley_amp = statistics.fmean(min_valley_amps), stats.sem(min_valley_amps)

# Summary of all timing/amplitude statistics, pickled at the end of the script.
onset_calc_dict = {
	'Avg Onset Time': avg_onset_time,
	'SEM Onset Time': sem_onset_time,
	'Avg Onset to Peak Time': avg_onset_to_peak,
	'SEM Onset to Peak Time': sem_onset_to_peak,
	'Avg Max Peak Amp': avg_max_peak_amp,
	'SEM Max Peak Amp': sem_max_peak_amp,
	'Avg Min Valley Amp': avg_min_valley_amp,
	'SEM Min Valley Amp': sem_min_valley_amp,
}
onset_dict['Onset to Peak Times'] = onset_to_peak

# Flag events whose onset-time z-score satisfies |z| >= OUTLIER_ZSCORE_THRESHOLD.
# Z-scores need a mean and a non-zero SD. With fewer than two valid events
# nothing is computed. When every valid onset time is identical (SD == 0) all
# z-scores are 'nan' and nothing is flagged; previously this raised
# ZeroDivisionError.
OUTLIER_ZSCORE_THRESHOLD = 2

outlier_flags_names = []
outlier_flags_index = []
outlier_flags_zscore = []

if event_num_check > 1:
	zscores_defined = std_onset_time != 0
	onset_time_zscores = []
	for onset_time in onset_dict['Peak Onset Times']:
		missing = onset_time == 'nan' or pd.isna(onset_time)
		if zscores_defined and not missing:
			onset_time_zscores.append((onset_time - avg_onset_time) / std_onset_time)
		else:
			onset_time_zscores.append('nan')

	# Record the 1-based event name, 0-based index and z-score of each outlier.
	for event_idx, zscore in enumerate(onset_time_zscores):
		if zscore != 'nan' and abs(zscore) >= OUTLIER_ZSCORE_THRESHOLD:
			outlier_flags_names.append(f'Event {event_idx + 1}')
			outlier_flags_index.append(event_idx)
			outlier_flags_zscore.append(zscore)

outlier_dict = {'Outlier Names': outlier_flags_names, 'Outlier Index': outlier_flags_index, 'Outlier Zscore': outlier_flags_zscore}

# AUC of the positive part of each event trace inside [AUC_WINDOW_START, AUC_WINDOW_END].
# Negative samples are clipped to 0. NaN samples (which satisfy neither comparison)
# are skipped, as before.
# The x-values passed to metrics.auc are now the time points of the samples
# actually kept. Previously time[0:len(item)] was used: the first N samples of
# the whole time axis, which gives the same area only when time is uniformly
# spaced.
AUC_WINDOW_START = -1
AUC_WINDOW_END = 5

auc_list = []
auc_dict = {}
auc_event_list = [[] for _ in event_list]  # clipped values inside the window, per event
auc_time_list = [[] for _ in event_list]   # matching time points, per event

for event_idx, trace in enumerate(event_list):
	if len(trace) > len(time):
		raise ValueError(f'event {event_idx + 1} has {len(trace)} samples but time has only {len(time)}')
	window_values = auc_event_list[event_idx]
	window_times = auc_time_list[event_idx]
	for t, point in zip(time, trace):
		if not (AUC_WINDOW_START <= t <= AUC_WINDOW_END):
			continue
		if point >= 0:
			window_values.append(point)
		elif point < 0:
			window_values.append(0)
		else:
			continue
		window_times.append(t)
	auc_list.append(metrics.auc(window_times, window_values))

auc_dict['AUC'] = auc_list
# fmean raises on an empty list; use the script's 'nan' sentinel instead.
if auc_list:
	auc_dict['Mean AUC'] = statistics.fmean(auc_list)
	auc_dict['SEM AUC'] = stats.sem(auc_list)
else:
	auc_dict['Mean AUC'] = 'nan'
	auc_dict['SEM AUC'] = 'nan'
					
# Dump the results to pickles in the current working directory.
# Each file is written inside 'with' so it is flushed and closed deterministically
# instead of relying on garbage collection of an unreferenced file object.
_pickle_outputs = {
	'onset_dict.pkl': onset_dict,
	'onset_calc_dict.pkl': onset_calc_dict,
	'outlier_dict.pkl': outlier_dict,
	'event_num_check.pkl': event_num_check,
	'auc_dict.pkl': auc_dict,
}
for out_path, payload in _pickle_outputs.items():
	with open(out_path, 'wb') as out_file:
		pickle.dump(payload, out_file)